import copy
import math
import os
import numpy as np
import torch
import numpy as np
import torch
from scipy.sparse import csr_matrix

from controlsnr import find_a_given_snr, solve_ab


def simple_collate_fn(batch):
    """
    Stack a batch of samples into dense tensors.

    Only usable when every graph in the batch has the same number of nodes,
    because the per-sample tensors are stacked along a new leading axis.

    Args:
        batch: sequence of dict-like samples. Each sample provides ``'adj'``
            (an object exposing ``.toarray()``, e.g. a SciPy sparse matrix)
            and ``'labels'`` (anything ``torch.tensor`` accepts).

    Returns:
        dict with
          - ``'adj'``: float32 tensor, [B, N, N]
          - ``'labels'``: int64 tensor, [B, N]
    """
    def _adjacency(sample):
        # Densify the stored adjacency and convert it to float32.
        return torch.tensor(sample['adj'].toarray(), dtype=torch.float32)

    def _node_labels(sample):
        return torch.tensor(sample['labels'], dtype=torch.long)

    # All adjacencies are converted first, then all labels.
    return {
        'adj': torch.stack(list(map(_adjacency, batch))),       # [B, N, N]
        'labels': torch.stack(list(map(_node_labels, batch))),  # [B, N]
    }


# ---- Config ----
# Number of graphs generated per (C, SNR, gamma) grid cell for the train / val splits
# (passed as `per_cell` by Generator.prepare_data).
per_cell_tr = 4
per_cell_v = 1

# Training SNR grid: 15 log-spaced values from 0.5 to 3.
snr_train = np.logspace(np.log10(0.5), np.log10(3), 15)
# Training gamma grid: 5 fixed values.
gamma_train = np.array([0.30, 1.20, 3.00, 4.00 ,5.00])

# Validation set: geometric midpoints between consecutive training SNR values.
snr_mid = np.sqrt(snr_train[:-1] * snr_train[1:])
# Pick 10 of those midpoints at evenly spaced (integer-truncated) indices.
idx = np.linspace(0, len(snr_mid) - 1, 10, dtype=int)
snr_val = snr_mid[idx]
gamma_val = gamma_train.copy()

# Optional sanity check: validation SNR points must not coincide with training points.
assert set(snr_val).isdisjoint(snr_train)
# assert set(C_val).isdisjoint(C_train)

# Test set. NOTE(review): the original comment said "(=200)"; what 200 refers to
# is not evident from this file — confirm or drop.
snr_test = (0.60,)
gamma_test = (0.15,)
# C_test = (10.0,)
per_cell_te = 1

class Generator(object):
    def __init__(self, N_train=50, N_test=100, N_val = 50,generative_model='SBM_multiclass', p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='dataset',
                 num_examples_train=100, num_examples_test=10, num_examples_val=10):
        self.N_train = N_train
        self.N_test = N_test
        self.N_val = N_val

        self.generative_model = generative_model
        self.p_SBM = p_SBM
        self.q_SBM = q_SBM
        self.n_classes = n_classes
        self.path_dataset = path_dataset

        self.data_train = None
        self.data_test = None
        self.data_val = None

        self.num_examples_train = num_examples_train
        self.num_examples_test = num_examples_test
        self.num_examples_val = num_examples_val

        # 初始化时直接生成
        self.C_train = self._make_C_grid(self.n_classes)
        # Validation_set
        c_mid = np.sqrt(self.C_train[:-1] * self.C_train[1:])
        # 从中均匀挑 10 个
        idx = np.linspace(0, len(c_mid) - 1, 10, dtype=int)
        self.C_val = c_mid[idx]
        self.C_test = (10,)  # 你可以换成别的逻辑


    def compute_C_bounds(self, k, margin=0.0):
        """
        Return the ``(C_min, C_max)`` range used to build the C grid.

        Args:
            k: number of classes.
            margin: relative slack added to the lower bound
                (``C_min = k * 3 * (1 + margin)``).

        Returns:
            Tuple ``(C_min, C_max)`` with ``C_max = C_min + 6 * k``.
        """
        # Hard-coded maximum SNR; presumably mirrors the upper end of the
        # module-level training SNR grid — verify if that grid changes.
        highest_snr = 3
        lower = k * highest_snr * (1.0 + float(margin))
        return lower, lower + k * 6

    def _make_C_grid(self, k):
        """生成 log-uniform 网格"""
        C_min, C_max = self.compute_C_bounds(k)
        if C_max <= C_min:
            return [C_min]
        return np.exp(np.linspace(np.log(C_min), np.log(C_max), 15))

    def SBM(self, p, q, N):
        W = np.zeros((N, N))

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        n = N // 2

        W[:n, :n] = np.random.binomial(1, p, (n, n))
        W[n:, n:] = np.random.binomial(1, p, (N-n, N-n))
        W[:n, n:] = np.random.binomial(1, q, (n, N-n))
        W[n:, :n] = np.random.binomial(1, q, (N-n, n))
        W = W * (np.ones(N) - np.eye(N))
        W = np.maximum(W, W.transpose())

        perm = torch.randperm(N).numpy()
        blockA = perm < n
        labels = blockA * 2 - 1

        W_permed = W[perm]
        W_permed = W_permed[:, perm]
        return W_permed, labels


    def SBM_multiclass(self, p, q, N, n_classes):

        p_prime = 1 - np.sqrt(1 - p)
        q_prime = 1 - np.sqrt(1 - q)

        prob_mat = np.ones((N, N)) * q_prime

        n = N // n_classes  # 基础类别大小
        remainder = N % n_classes  # 不能整除的剩余部分
        n_last = n + remainder  # 最后一类的大小

        # 先对整除部分进行块状分配
        for i in range(n_classes - 1):  # 处理前 n_classes-1 类
            prob_mat[i * n: (i + 1) * n, i * n: (i + 1) * n] = p_prime

        # 处理最后一类
        start_idx = (n_classes - 1) * n  # 最后一类的起始索引
        prob_mat[start_idx: start_idx + n_last, start_idx: start_idx + n_last] = p_prime

        # 生成邻接矩阵
        W = np.random.rand(N, N) < prob_mat
        W = W.astype(int)

        W = W * (np.ones(N) - np.eye(N))  # 移除自环
        W = np.maximum(W, W.transpose())  # 确保无向图

        # 随机打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成类别标签
        labels =np.minimum((perm // n) , n_classes - 1)

        W_permed = W[perm]
        W_permed = W_permed[:, perm]

        #计算P矩阵的特征向量
        prob_mat_permed = prob_mat[perm][:, perm]
        # np.fill_diagonal(prob_mat_permed, 0)  # 去除自环

        eigvals, eigvecs = np.linalg.eigh(prob_mat_permed)
        idx = np.argsort(eigvals)[::-1]
        eigvecs_top = eigvecs[:, idx[:n_classes]]

        return W_permed, labels, eigvecs_top  # 返回前n_classes特征向量

    def imbalanced_SBM_multiclass(self, p, q, N, n_classes, class_sizes):

        # 上三角采样不会放大概率，直接用目标 p, q
        p_prime = float(p)
        q_prime = float(q)

        # 构造期望矩阵（块内 p，块间 q），无自环
        prob_mat = np.full((N, N), q_prime, dtype=float)
        boundaries = np.cumsum([0] + class_sizes)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            prob_mat[start:end, start:end] = p_prime
        np.fill_diagonal(prob_mat, 0.0)

        # —— 关键修改：只采样上三角，然后镜像 —— #
        W = np.zeros((N, N), dtype=np.uint8)
        iu, ju = np.triu_indices(N, k=1)
        W[iu, ju] = (np.random.rand(iu.size) < prob_mat[iu, ju]).astype(np.uint8)
        W = (W + W.T).astype(np.uint8)  # 无向化；对角仍为 0

        # 打乱节点顺序
        perm = torch.randperm(N).numpy()

        # 生成并置乱标签
        labels = np.zeros(N, dtype=int)
        for i in range(n_classes):
            start, end = boundaries[i], boundaries[i + 1]
            labels[start:end] = i
        labels = labels[perm]

        # 同步置乱矩阵
        W_permed = W[perm][:, perm]

        return W_permed, labels

    def prepare_data(self):
        """
        Locate (or generate) the train / val / test datasets.

        Each split lives in a directory under ``self.path_dataset`` whose name
        is built from the generative model, the number of classes, the nominal
        graph size and the example count. A split whose directory contains no
        ``.npz`` file is generated with ``create_dataset_grid`` using the
        module-level SNR / gamma grids and this instance's C grids.

        Afterwards ``self.data_train``, ``self.data_val`` and
        ``self.data_test`` are sorted lists of ``.npz`` file paths.
        """
        def list_npz(folder):
            # Sorted names of the .npz files directly inside `folder`.
            return sorted([name for name in os.listdir(folder) if name.endswith(".npz")])

        def get_npz_dataset(path, mode, *, snr_grid, gamma_grid, C_grid, per_cell, min_size=50, base_seed=0):
            # Make sure the directory exists, announcing a fresh creation.
            if not os.path.exists(path):
                os.makedirs(path)
                print(f"[创建数据集] {mode} 数据目录不存在，已新建：{path}")

            npz_files = list_npz(path)
            if npz_files:
                print(f"[读取数据] {mode} 集已存在，共 {len(npz_files)} 张图：{path}")
            else:
                # Nothing on disk yet: generate the grid, then list again.
                print(f"[创建数据集] {mode} 数据未找到，开始生成...")
                self.create_dataset_grid(
                    path,
                    mode=mode,
                    snr_grid=snr_grid,
                    gamma_grid=gamma_grid,
                    C_grid=C_grid,
                    per_cell=per_cell,
                    min_size=min_size,
                    base_seed=base_seed,
                )
                npz_files = list_npz(path)
            return [os.path.join(path, name) for name in npz_files]

        # ---- Directory names ----
        train_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gstr{self.N_train}_numtr{self.num_examples_train}"
        test_dir = f"{self.generative_model}_nc{self.n_classes}_rand_gste{self.N_test}_numte{self.num_examples_test}"
        val_dir = f"{self.generative_model}_nc{self.n_classes}_rand_val{self.N_val}_numval{self.num_examples_val}"

        root = self.path_dataset
        train_path = os.path.join(root, train_dir)
        test_path = os.path.join(root, test_dir)
        val_path = os.path.join(root, val_dir)

        # ---- One parameter set per split, processed as train -> val -> test ----
        self.data_train = get_npz_dataset(
            train_path, 'train',
            snr_grid=snr_train, gamma_grid=gamma_train, C_grid=self.C_train,
            per_cell=per_cell_tr, min_size=10, base_seed=123,
        )
        self.data_val = get_npz_dataset(
            val_path, 'val',
            snr_grid=snr_val, gamma_grid=gamma_val, C_grid=self.C_val,
            per_cell=per_cell_v, min_size=10, base_seed=2025,
        )
        self.data_test = get_npz_dataset(
            test_path, 'test',
            snr_grid=snr_test, gamma_grid=gamma_test, C_grid=self.C_test,
            per_cell=per_cell_te, min_size=10, base_seed=31415,
        )


    def sample_single(self, i, is_training=True):
        if is_training:
            dataset = self.data_train
        else:
            dataset = self.data_test
        example = dataset[i]
        if (self.generative_model == 'SBM_multiclass'):
            W_np = example['W']
            labels = np.expand_dims(example['labels'], 0)
            labels_var = torch.from_numpy(labels)
            if is_training:
                labels_var.requires_grad = True
            return W_np, labels_var


    def sample_otf_single(self, is_training=True, cuda=True):
        """
        Sample one graph on the fly with ``self.p_SBM`` / ``self.q_SBM``.

        Args:
            is_training: use ``self.N_train`` nodes when true, else ``self.N_test``.
            cuda: unused; kept for interface compatibility.

        Returns:
            ``(W, labels, eigvecs_top)`` where ``W`` is a NumPy array of shape
            [1, N, N], ``labels`` a tensor of shape [1, N], and ``eigvecs_top``
            the eigenvectors returned by ``SBM_multiclass`` (``None`` for the
            two-block ``'SBM'`` model, which does not compute any).

        Raises:
            ValueError: if ``self.generative_model`` is not supported.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # SBM() yields no eigenvectors; bind the name so the return below
            # does not fail with UnboundLocalError.
            eigvecs_top = None
        elif self.generative_model == 'SBM_multiclass':
            W, labels, eigvecs_top = self.SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add a leading batch axis; W deliberately stays a NumPy array.
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top

    def imbalanced_sample_otf_single(self, class_sizes , is_training=True, cuda=True):
        if is_training:
            N = self.N_train
        else:
            N = self.N_test
        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
        elif self.generative_model == 'SBM_multiclass':
            W, labels,eigvecs_top = self.imbalanced_SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes, class_sizes)
        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        labels = np.expand_dims(labels, 0)
        labels = torch.from_numpy(labels)

        W = np.expand_dims(W, 0)
        W = torch.tensor(W, dtype=torch.float32)  # 不加 requires_grad

        return W, labels, eigvecs_top

    def random_sample_otf_single(self, C = 10 ,is_training=True, cuda=True):
        """
        Sample one graph on the fly with randomly drawn SBM parameters.

        For ``'SBM_multiclass'`` the a/b ratio range is derived from
        ``find_a_given_snr`` evaluated at SNR 0.1 and 1; ``p``, ``q`` and the
        community sizes are then drawn by
        ``random_imbalanced_SBM_generator_balanced_sampling`` (minimum
        community size 20) and the graph by ``imbalanced_SBM_multiclass``.
        For ``'SBM'`` the fixed ``self.p_SBM`` / ``self.q_SBM`` are used.

        Args:
            C: value forwarded to ``find_a_given_snr`` and the parameter
                sampler (``a + (k - 1) * b = C`` there).
            is_training: use ``self.N_train`` nodes when true, else ``self.N_test``.
            cuda: unused; kept for interface compatibility.

        Returns:
            ``(W, labels, eigvecs_top, snr, class_sizes)`` where ``W`` is a
            NumPy array of shape [1, N, N] and ``labels`` a tensor of shape
            [1, N]. For ``'SBM'`` the last three entries are ``None`` because
            that model computes none of them.

        Raises:
            ValueError: if ``self.generative_model`` is not supported.
        """
        N = self.N_train if is_training else self.N_test

        if self.generative_model == 'SBM':
            W, labels = self.SBM(self.p_SBM, self.q_SBM, N)
            # Bind the remaining outputs so the return below does not fail
            # with UnboundLocalError.
            eigvecs_top = snr = class_sizes = None

        elif self.generative_model == 'SBM_multiclass':
            a_low, b_low = find_a_given_snr(0.1, self.n_classes, C)
            a_high, b_high = find_a_given_snr(1, self.n_classes, C)

            # Range of the ratio a / b; keep it ordered.
            lower_bound = a_low / b_low
            upper_bound = a_high / b_high
            if lower_bound > upper_bound:
                lower_bound, upper_bound = upper_bound, lower_bound

            p, q, class_sizes, snr = self.random_imbalanced_SBM_generator_balanced_sampling(
                N=N,
                n_classes=self.n_classes,
                C=C,
                alpha_range=(lower_bound, upper_bound),
                min_size= 20
            )
            # imbalanced_SBM_multiclass returns exactly two values.
            W, labels = self.imbalanced_SBM_multiclass(p, q, N, self.n_classes, class_sizes)

            # Leading eigenvectors of the expected block matrix, rebuilt from
            # the labels; this mirrors what SBM_multiclass returns.
            same_block = labels[:, None] == labels[None, :]
            expected = np.where(same_block, float(p), float(q))
            eigvals, eigvecs = np.linalg.eigh(expected)
            leading = np.argsort(eigvals)[::-1][:self.n_classes]
            eigvecs_top = eigvecs[:, leading]

        else:
            raise ValueError('Generative model {} not supported'.format(self.generative_model))

        # Add a leading batch axis; W deliberately stays a NumPy array.
        labels = torch.from_numpy(np.expand_dims(labels, 0))
        W = np.expand_dims(W, 0)

        return W, labels, eigvecs_top, snr, class_sizes


    def random_imbalanced_SBM_generator_balanced_sampling(self, N, n_classes, C, *,
                                        alpha_range=(1.3, 2.8),
                                        min_size=5):
        """
        随机生成 SBM 模型的参数，社区大小为随机比例但总和为 N。
        返回 p, q, class_sizes, a, b, snr。
        """
        assert N >= min_size * n_classes

        # Step 1: 随机生成 a > b，使得 a + (k - 1) * b = C
        alpha = np.random.uniform(*alpha_range)
        b = C / (alpha + (n_classes - 1))
        a = alpha * b

        # Step 2: 计算边连接概率
        logn = np.log(N)
        p = a * logn / N
        q = b * logn / N

        # ✅ Step 3: 使用 Dirichlet 生成 class_sizes
        remaining = N - min_size * n_classes
        probs = np.random.dirichlet(np.ones(n_classes))  # 总和为1的概率向量
        extras = np.random.multinomial(remaining, probs)
        class_sizes = [min_size + e for e in extras]

        # Step 4: 计算 SNR
        snr = (a - b) ** 2 / (n_classes * (a + (n_classes - 1) * b))

        return p, q, class_sizes, snr

    def imbalanced_SBM_multiclass_B(self, N, n_classes, class_sizes, B_prob, rng=None, shuffle=True):
        """
        用块概率矩阵 B_prob 生成一张 SBM 图
        输入:
          - N: 节点数
          - n_classes: 社区数
          - class_sizes: 各社区大小
          - B_prob: (k x k) 块概率矩阵，已含 logn/n 缩放
          - shuffle: 是否打乱节点顺序 (默认 True)
        返回:
          - W_dense: N×N 邻接矩阵
          - labels: 节点的社区标签 (打乱后顺序一致)
        """
        import numpy as np
        rng = np.random.default_rng() if rng is None else rng

        # 构造标签
        labels = np.zeros(N, dtype=int)
        start = 0
        for r, size in enumerate(class_sizes):
            labels[start:start + size] = r
            start += size

        # === 新增: shuffle 节点顺序 ===
        if shuffle:
            perm = rng.permutation(N)
            labels = labels[perm]
        else:
            perm = np.arange(N)

        # 初始化邻接矩阵
        W_dense = np.zeros((N, N), dtype=np.float32)

        # 遍历上三角，按 B_prob 采样
        for i in range(N):
            for j in range(i + 1, N):
                r, s = labels[i], labels[j]
                prob = B_prob[r, s]
                if rng.random() < prob:
                    W_dense[i, j] = 1.0
                    W_dense[j, i] = 1.0  # 对称化

        return W_dense, labels

    def _sample_class_sizes_dirichlet(self, N, n_classes, gamma, min_size, rng,
                                      gamma_jitter=0.5):
        """按 Dirichlet(gamma) 采样类比例 + min_size 下界，支持 gamma 抖动。"""
        assert N >= min_size * n_classes
        remaining = N - min_size * n_classes

        # 对 gamma 做整体抖动：γ' = γ * U(1-δ, 1+δ)
        if gamma_jitter and gamma_jitter > 0:
            mult = rng.uniform(1.0 - gamma_jitter, 1.0 + gamma_jitter)
            gamma_used = gamma * mult
        else:
            gamma_used = gamma

        probs = rng.dirichlet(np.full(n_classes, gamma_used, dtype=float))
        extras = rng.multinomial(remaining, probs)
        return [min_size + int(e) for e in extras]  # 也可以返回实际使用的 γ

    def gen_one_sbm_by_targets(
            self, N, n_classes, C, target_snr, gamma, min_size=5, *, rng=None,
            heterophily=False, hetero_prob=None,
            ab_jitter=0.05, keep_assortativity=True,
            pq_jitter=(0.05, 0.05),  # must be a 2-tuple: (pj, qj)
            C_jitter=0.1, C_jitter_mode='relative',
            b_floor=1e-6
    ):
        """
        Build a jittered SBM block probability matrix for the given targets.

        Args:
            N: number of nodes (sets the ``log(N) / N`` scale and the sizes).
            n_classes: number of communities (k).
            C: nominal value passed to ``find_a_given_snr``.
            target_snr: target SNR passed to ``find_a_given_snr``.
            gamma: Dirichlet concentration for the community sizes.
            min_size: minimum community size.
            rng: ``numpy.random.Generator``; a fresh unseeded one if ``None``.
            heterophily: swap a and b when true (used if ``hetero_prob`` is None).
            hetero_prob: if given, a and b are swapped with this probability.
            ab_jitter: relative jitter applied independently to a and b.
            keep_assortativity: for non-heterophilous graphs, force every
                off-diagonal entry below both diagonal entries it connects.
            pq_jitter: ``(pj, qj)`` relative jitter of each diagonal /
                off-diagonal block entry.
            C_jitter: magnitude of the jitter applied to C.
            C_jitter_mode: ``'relative'`` (multiplicative) or ``'absolute'``
                (additive); any other value applies no jitter.
            b_floor: lower bound applied to a and b after jitter.

        Returns:
            ``(B_prob, class_sizes, a, b, gamma, is_hetero, C_used)`` where
            ``B_prob`` is a symmetric (k x k) matrix of edge probabilities.
        """
        rng = np.random.default_rng() if rng is None else rng

        # === 0) Jitter C ===
        C_used = float(C)
        if C_jitter and C_jitter > 0:
            if C_jitter_mode == 'relative':
                C_used *= rng.uniform(1.0 - C_jitter, 1.0 + C_jitter)
            elif C_jitter_mode == 'absolute':
                C_used += rng.uniform(-C_jitter, C_jitter)
            C_used = max(C_used, 1e-6)

        # === 1) Community sizes ===
        class_sizes = self._sample_class_sizes_dirichlet(N, n_classes, gamma, min_size, rng)

        # === 2) Solve for a, b ===
        # NOTE(review): a and b are solved from the nominal C; the jittered
        # C_used is only returned. Switching to C_used may be intended, but a
        # downward jitter could make (target_snr, C_used) infeasible for
        # find_a_given_snr — confirm before changing.
        a, b = find_a_given_snr(target_snr, n_classes, C)

        # === 3) Assortative / disassortative switch ===
        if hetero_prob is not None:
            is_hetero = bool(rng.random() < float(hetero_prob))
        else:
            is_hetero = bool(heterophily)
        if is_hetero:
            a, b = b, a

        # === 4) Independent jitter of a and b ===
        if ab_jitter and ab_jitter > 0:
            a *= rng.uniform(1.0 - ab_jitter, 1.0 + ab_jitter)
            b *= rng.uniform(1.0 - ab_jitter, 1.0 + ab_jitter)
        a, b = max(a, b_floor), max(b, b_floor)

        # === 5) Block probability matrix ===
        scale = np.log(N) / N
        B_prob = np.zeros((n_classes, n_classes), dtype=float)

        pj, qj = pq_jitter

        # Diagonal: every p_rr gets its own jitter factor.
        for r in range(n_classes):
            B_prob[r, r] = a * scale * rng.uniform(1.0 - pj, 1.0 + pj)

        # Off-diagonal: every q_rs (r < s) gets its own factor, mirrored.
        for r in range(n_classes):
            for s in range(r + 1, n_classes):
                B_prob[r, s] = B_prob[s, r] = b * scale * rng.uniform(1.0 - qj, 1.0 + qj)

        # === 6) Keep the graph assortative ===
        # Clamp each off-diagonal pair against the smaller of its two diagonal
        # entries and write both mirrored entries at once, so B_prob stays
        # symmetric. Clamping (r, s) and (s, r) separately against different
        # diagonals could leave the matrix asymmetric.
        if keep_assortativity and not is_hetero:
            diag_vals = np.diag(B_prob).copy()
            for r in range(n_classes):
                for s in range(r + 1, n_classes):
                    limit = min(diag_vals[r], diag_vals[s])
                    if B_prob[r, s] >= limit:
                        B_prob[r, s] = B_prob[s, r] = max(1e-12, limit * 0.9)

        return B_prob, class_sizes, a, b, gamma, is_hetero, C_used


    def create_dataset_grid(self, directory, mode='train', *,
                            snr_grid=(0.6, 0.9, 1.1, 1.3, 1.6, 2.0, 2.5, 3.0),
                            gamma_grid=(0.15, 0.3, 0.6, 1.0, 2.0),
                            C_grid=(10.0,),  # e.g. (6, 10, 14) for several densities
                            per_cell=20,
                            min_size=5,
                            base_seed=0,
                            size_jitter=500):
        """
        Generate a dataset on the Cartesian grid (C x SNR x gamma).

        ``per_cell`` graphs are generated for every grid point and each is
        written to its own compressed ``.npz`` file. The file name encodes the
        grid point; the file stores the CSR adjacency components, the labels
        and the generation metadata.

        Args:
            directory: output directory (created if missing).
            mode: ``'train'``, ``'val'`` or ``'test'``; selects the nominal
                graph size and sets the matching ``self.data_*`` attribute
                to ``directory``.
            snr_grid, gamma_grid, C_grid: grid values.
            per_cell: number of graphs per grid point.
            min_size: minimum community size.
            base_seed: base of the per-cell RNG seeds.
            size_jitter: the node count of each graph is drawn uniformly
                within +/- this amount around the nominal size.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        os.makedirs(directory, exist_ok=True)

        if mode == 'train':
            N = self.N_train
            self.data_train = directory
        elif mode == 'val':
            N = self.N_val
            self.data_val = directory
        elif mode == 'test':
            N = self.N_test
            self.data_test = directory
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        num_graphs_expected = len(snr_grid) * len(gamma_grid) * len(C_grid) * per_cell

        # Smallest usable graph: room for min_size nodes per community, and at
        # least 2 nodes so that log(N) / N is positive.
        min_nodes = max(min_size * self.n_classes, 2)

        idx = 0
        for c_idx, C in enumerate(C_grid):
            for s_idx, snr_target in enumerate(snr_grid):
                for g_idx, gamma in enumerate(gamma_grid):
                    # One RNG per grid cell, so a cell can be regenerated on its own.
                    cell_seed = base_seed + (c_idx * 10_000_000
                                             + s_idx * 10_000
                                             + g_idx * 100)
                    rng = np.random.default_rng(cell_seed)

                    for rep in range(per_cell):
                        # Draw the graph size from the seeded per-cell RNG (not
                        # the global one) and keep it large enough to sample.
                        rand_N = int(N + rng.uniform(-1.0, 1.0) * size_jitter)
                        rand_N = max(rand_N, min_nodes)

                        B_prob, class_sizes, a, b, _, _, _ = self.gen_one_sbm_by_targets(
                            N=rand_N, n_classes=self.n_classes, C=C,
                            target_snr=snr_target, gamma=gamma,
                            min_size=min_size, rng=rng, pq_jitter=(0.02, 0.05)
                        )

                        # The graph must have rand_N nodes, since class_sizes
                        # sums to rand_N. Sampling uses the same seeded RNG.
                        W_dense, labels = self.imbalanced_SBM_multiclass_B(
                            rand_N, self.n_classes, class_sizes, B_prob, rng=rng
                        )

                        W_sparse = csr_matrix(W_dense)

                        fname = (f"{mode}_N{rand_N}_i{idx:05d}"
                                 f"__C{C:.2f}__snr{snr_target:.3f}"
                                 f"__g{gamma:.3f}__rep{rep:02d}.npz")
                        path = os.path.join(directory, fname)

                        np.savez_compressed(
                            path,
                            adj_data=W_sparse.data,
                            adj_indices=W_sparse.indices,
                            adj_indptr=W_sparse.indptr,
                            adj_shape=W_sparse.shape,
                            labels=labels,
                            a=a, b=b,
                            C=C,
                            snr_target=snr_target,
                            gamma=gamma,
                            class_sizes=np.array(class_sizes, dtype=np.int32),
                        )
                        idx += 1

        print(f"[{mode}] 网格数据完成: 共 {idx} 张（期望 {num_graphs_expected}）。目录: {directory}")

    def copy(self):
        """Return an independent deep copy of this generator."""
        from copy import deepcopy

        clone = deepcopy(self)
        return clone